# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import math
from argparse import ArgumentParser
from typing import TYPE_CHECKING

from olive.cli.base import BaseOliveCLICommand, add_logging_options
from olive.common.utils import WeightsFileFormat, save_weights

if TYPE_CHECKING:
    from numpy.typing import NDArray


class ConvertAdaptersCommand(BaseOliveCLICommand):
    """CLI command that exports peft LoRA adapter weights to a standalone file.

    The exported file is intended to be consumed by ONNX models generated by
    Olive's ExtractedAdapters pass. Weights may optionally be quantized to
    int4 with blockwise quantization.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `convert-adapters` sub-command and its arguments on `parser`."""
        sub_parser = parser.add_parser(
            "convert-adapters",
            help=(
                "Convert lora adapter weights to a file that will be consumed by ONNX models generated by Olive"
                " ExtractedAdapters pass."
            ),
        )
        sub_parser.add_argument(
            "-a",
            "--adapter_path",
            type=str,
            required=True,
            help="Path to the adapters weights saved after peft fine-tuning. Can be a local folder or huggingface id.",
        )
        # NOTE(review): the default is the enum member while CLI-supplied values are
        # the plain `.value` strings (argparse does not coerce defaults through
        # `choices`) — presumably save_weights accepts both; verify.
        sub_parser.add_argument(
            "--adapter_format",
            type=str,
            default=WeightsFileFormat.ONNX_ADAPTER,
            choices=[el.value for el in WeightsFileFormat],
            help=f"Format to save the weights in. Default is {WeightsFileFormat.ONNX_ADAPTER}.",
        )
        sub_parser.add_argument(
            "-o",
            "--output_path",
            type=str,
            required=True,
            help="Path to save the exported weights. Will be saved in the `adapter_format` format.",
        )
        sub_parser.add_argument(
            "--dtype",
            type=str,
            default="float32",
            choices=["float32", "float16"],
            help=(
                "Data type to save float adapter weights as. If quantize_int4 is True, this is the data type of the"
                " quantization scales. Default is float32."
            ),
        )
        # int4 quantization options for adapter weights
        sub_parser.add_argument(
            "--quantize_int4",
            action="store_true",
            help="Quantize the adapter weights to int4 using blockwise quantization.",
        )
        sub_parser.add_argument(
            "--int4_block_size",
            type=int,
            default=32,
            choices=[16, 32, 64, 128, 256],
            help="Block size for int4 quantization of adapter weights. Default is 32.",
        )
        sub_parser.add_argument(
            "--int4_quantization_mode",
            type=str,
            default="symmetric",
            choices=["symmetric", "asymmetric"],
            help="Quantization mode for int4 quantization of adapter weights. Default is symmetric.",
        )
        add_logging_options(sub_parser)
        sub_parser.set_defaults(func=ConvertAdaptersCommand)

    def run(self):
        """Load the peft adapter, transform (and optionally quantize) its weights, and save them.

        Raises:
            ValueError: if the adapter is a DoRA adapter, which is not supported.
        """
        import torch
        from peft import LoraConfig, load_peft_weights

        lora_config = LoraConfig.from_pretrained(self.args.adapter_path)

        if getattr(lora_config, "use_dora", False):
            raise ValueError("DoRA adapters are not supported for export.")

        # compute scaling factor for LoRA
        # use_rslora was only added in peft 0.8.0, so treat it as False on older versions
        if getattr(lora_config, "use_rslora", False):
            scaling = lora_config.lora_alpha / math.sqrt(lora_config.r)
        else:
            scaling = lora_config.lora_alpha / lora_config.r

        adapter_weights = load_peft_weights(self.args.adapter_path, device="cpu")

        transformed_weights = {}
        for name, value in adapter_weights.items():
            new_name = name.replace("base_model.model.model", "model")
            # cast to dtype first since some dtypes like bfloat16 are not supported by numpy
            # need to copy since the numpy array is read-only
            float_weight = value.to(getattr(torch, self.args.dtype)).numpy().transpose().copy()
            if "lora_B" in new_name:
                # fold the LoRA scaling factor into the B matrix so consumers don't need it
                float_weight *= scaling
            if not self.args.quantize_int4:
                transformed_weights[new_name] = float_weight
            else:
                weight, scale, zero_point = self.int4_block_quant(
                    float_weight, self.args.int4_block_size, self.args.int4_quantization_mode == "symmetric"
                )
                transformed_weights[new_name.replace(".weight", ".quant.weight")] = weight
                transformed_weights[new_name.replace(".weight", ".quant.scale")] = scale
                if self.args.int4_quantization_mode == "asymmetric":
                    # otherwise it's always 0 and not part of the node inputs
                    transformed_weights[new_name.replace(".weight", ".quant.zero_point")] = zero_point

        output_path = save_weights(transformed_weights, self.args.output_path, self.args.adapter_format)
        print(f"Exported adapter weights to {output_path}")

    @staticmethod
    def int4_block_quant(
        float_weight: "NDArray", block_size: int, is_symmetric: bool
    ) -> tuple["NDArray", "NDArray", "NDArray"]:
        """Quantize a 2-D weight tensor to int4 using blockwise quantization.

        Args:
            float_weight: 2-D float array of shape (rows, cols); each quantization
                block is taken from a single column.
            block_size: number of elements per quantization block.
            is_symmetric: if True, use symmetric quantization (zero point fixed at 0).

        Returns:
            Tuple of (packed int4 weights as uint8, per-block scales, zero points
            packed two-per-byte as uint8).
        """
        # Only need to quantize the weight tensors directly
        # Not the same as OnnxBlockWiseRtnQuantization pass which quantizes an entire model
        import numpy as np
        from onnxruntime import __version__ as ort_version
        from packaging import version

        # the quantizer helper moved from matmul_4bits_quantizer to
        # matmul_nbits_quantizer in onnxruntime 1.22
        if version.parse(ort_version) < version.parse("1.22.0"):
            from onnxruntime.quantization.matmul_4bits_quantizer import quantize_matmul_4bits
        else:
            from onnxruntime.quantization.matmul_nbits_quantizer import quantize_matmul_4bits

        rows, cols = float_weight.shape

        # two int4 values are packed into each uint8 byte
        blob_size = block_size // 2
        # number of blocks along the row (K) dimension, rounded up
        k_blocks = (rows + block_size - 1) // block_size
        padded_rows = k_blocks * block_size
        pad_len = padded_rows - rows
        if pad_len > 0:
            # zero-pad so the row count is an exact multiple of block_size
            float_weight = np.pad(float_weight, ((0, pad_len), (0, 0)), "constant")

        # block wise quantization, each block comes from a single column
        packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
        scales = np.zeros((cols * k_blocks), dtype=float_weight.dtype)
        # zero points are 4-bit as well, so two blocks share one byte
        zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8")
        quantize_matmul_4bits(packed, float_weight, scales, zero_point, block_size, cols, rows, is_symmetric)

        return packed, scales, zero_point
